{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a57d6663-4055-40a6-a9fc-dffb38c418b1",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:17:54.485482Z",
     "iopub.status.busy": "2024-09-02T07:17:54.485266Z",
     "iopub.status.idle": "2024-09-02T07:17:54.488773Z",
     "shell.execute_reply": "2024-09-02T07:17:54.488208Z",
     "shell.execute_reply.started": "2024-09-02T07:17:54.485464Z"
    }
   },
   "outputs": [],
   "source": [
    "%%capture --no-stderr\n",
    "# Install the notebook's third-party dependencies into the kernel's own environment (%pip, not !pip).\n",
    "# `shutil` belongs to the Python standard library and must not be pip-installed, so it is not listed.\n",
    "# jieba and seaborn are imported by later cells; rank_bm25 is needed by langchain's BM25Retriever.\n",
    "%pip install -U langchain langchain_community langchain_openai pypdf sentence_transformers chromadb openpyxl FlagEmbedding jieba rank_bm25 seaborn"
   ]
  },
  {
   "cell_type": "code",
   "id": "603ae327-d834-4074-b8f0-5801669a6f26",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:17:54.490034Z",
     "iopub.status.busy": "2024-09-02T07:17:54.489846Z",
     "iopub.status.idle": "2024-09-02T07:17:54.503323Z",
     "shell.execute_reply": "2024-09-02T07:17:54.502676Z",
     "shell.execute_reply.started": "2024-09-02T07:17:54.490016Z"
    },
    "ExecuteTime": {
     "end_time": "2024-09-13T15:21:12.163191Z",
     "start_time": "2024-09-13T15:21:12.160910Z"
    }
   },
   "source": [
    "# Export the LLM endpoint and credential as environment variables for this kernel.\n",
    "# The LLM_API_KEY value below is a placeholder: replace it with your own key before running,\n",
    "# and do not commit a real key together with the notebook.\n",
    "%env LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1\n",
    "%env LLM_API_KEY=替换为自己的key"
   ],
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1\n",
      "env: LLM_API_KEY=替换为自己的key\n"
     ]
    }
   ],
   "execution_count": 1
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1b6c2792-91bf-464a-bb0d-9724f790c6f9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:17:54.504416Z",
     "iopub.status.busy": "2024-09-02T07:17:54.504170Z",
     "iopub.status.idle": "2024-09-02T07:17:58.771665Z",
     "shell.execute_reply": "2024-09-02T07:17:58.771173Z",
     "shell.execute_reply.started": "2024-09-02T07:17:54.504393Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/lib/python3.10/site-packages/sentence_transformers/cross_encoder/CrossEncoder.py:11: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
      "  from tqdm.autonotebook import tqdm, trange\n",
      "2024-09-02 15:17:56.333336: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/cuda/lib64:\n",
      "2024-09-02 15:17:56.333352: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "langchain                     0.2.10\n",
      "langchain_core                0.2.28\n",
      "langchain_community           0.2.9\n",
      "pypdf                         4.3.1\n",
      "sentence_transformers         3.0.1\n",
      "chromadb                      0.5.4\n"
     ]
    }
   ],
   "source": [
    "import langchain, langchain_community, pypdf, sentence_transformers, chromadb, langchain_core\n",
    "\n",
    "# Print one aligned line per core dependency: name left-padded to 30 columns, then its version.\n",
    "core_packages = (langchain, langchain_core, langchain_community, pypdf, sentence_transformers, chromadb)\n",
    "for module in core_packages:\n",
    "    print('{:<30}{}'.format(module.__name__, module.__version__))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1e2c72b8-ee12-4130-af88-699998aa230c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:17:58.772496Z",
     "iopub.status.busy": "2024-09-02T07:17:58.772153Z",
     "iopub.status.idle": "2024-09-02T07:17:58.774847Z",
     "shell.execute_reply": "2024-09-02T07:17:58.774530Z",
     "shell.execute_reply.started": "2024-09-02T07:17:58.772482Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "841d2b02-ad06-40d2-b11f-c7adccec6ca2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:17:58.776142Z",
     "iopub.status.busy": "2024-09-02T07:17:58.775993Z",
     "iopub.status.idle": "2024-09-02T07:17:58.788582Z",
     "shell.execute_reply": "2024-09-02T07:17:58.788152Z",
     "shell.execute_reply.started": "2024-09-02T07:17:58.776131Z"
    }
   },
   "outputs": [],
   "source": [
    "# Name of this experiment; it is used as the sub-folder name under `experiments`.\n",
    "expr_version = 'retrieval_v6_rerank_ft'\n",
    "\n",
    "# Both folders are addressed relative to the parent of the current working directory.\n",
    "preprocess_output_dir = os.path.join(os.pardir, 'outputs', 'v1_20240713')\n",
    "expr_dir = os.path.join(os.pardir, 'experiments', expr_version)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf7e81e3-4c82-4842-aef5-7592caaf1d39",
   "metadata": {},
   "source": [
    "# 读取文档"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e6920e29-bc7d-4635-be06-d151eaf0e100",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:17:58.789184Z",
     "iopub.status.busy": "2024-09-02T07:17:58.789048Z",
     "iopub.status.idle": "2024-09-02T07:18:00.223585Z",
     "shell.execute_reply": "2024-09-02T07:18:00.223111Z",
     "shell.execute_reply.started": "2024-09-02T07:17:58.789159Z"
    }
   },
   "outputs": [],
   "source": [
    "from langchain_community.document_loaders import PyPDFLoader\n",
    "\n",
    "# Load the source report PDF; the loader's result is kept in `documents`.\n",
    "report_pdf_path = os.path.join(os.path.pardir, 'data', '2024全球经济金融展望报告.pdf')\n",
    "loader = PyPDFLoader(report_pdf_path)\n",
    "documents = loader.load()\n",
    "\n",
    "# Question/answer sheet read from the preprocessing output folder.\n",
    "qa_file_path = os.path.join(preprocess_output_dir, 'question_answer.xlsx')\n",
    "qa_df = pd.read_excel(qa_file_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "841ec659-4ad7-4e1f-b1ea-3477bf97fde3",
   "metadata": {},
   "source": [
    "# 文档切分"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "74fe856a-7c19-4c3c-bb30-7abfa6298f74",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:18:00.224276Z",
     "iopub.status.busy": "2024-09-02T07:18:00.224052Z",
     "iopub.status.idle": "2024-09-02T07:18:00.230468Z",
     "shell.execute_reply": "2024-09-02T07:18:00.230131Z",
     "shell.execute_reply.started": "2024-09-02T07:18:00.224263Z"
    }
   },
   "outputs": [],
   "source": [
    "from uuid import uuid4\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "\n",
    "def split_docs(documents, filepath, chunk_size=400, chunk_overlap=40, seperators=['\\n\\n\\n', '\\n\\n'], force_split=False):\n",
    "    \"\"\"Split documents into chunks and cache the result as a pickle at ``filepath``.\n",
    "\n",
    "    Args:\n",
    "        documents: Documents handed to ``RecursiveCharacterTextSplitter.split_documents``.\n",
    "        filepath: Path of the pickle cache file.\n",
    "        chunk_size: ``chunk_size`` given to the splitter.\n",
    "        chunk_overlap: ``chunk_overlap`` given to the splitter.\n",
    "        seperators: Separator list given to the splitter (spelling kept for existing callers).\n",
    "        force_split: When True, ignore an existing cache file and split again.\n",
    "\n",
    "    Returns:\n",
    "        The chunks. Freshly split chunks get a random ``uuid`` entry in their metadata.\n",
    "\n",
    "    Note:\n",
    "        An existing cache is returned as-is, so chunk_size / chunk_overlap / seperators\n",
    "        have no effect unless ``force_split=True``.\n",
    "    \"\"\"\n",
    "    if os.path.exists(filepath) and not force_split:\n",
    "        print('found cache, restoring...')\n",
    "        # Only load caches written by this notebook: unpickling an untrusted file can run code.\n",
    "        with open(filepath, 'rb') as cache_file:\n",
    "            return pickle.load(cache_file)\n",
    "\n",
    "    splitter = RecursiveCharacterTextSplitter(\n",
    "        chunk_size=chunk_size,\n",
    "        chunk_overlap=chunk_overlap,\n",
    "        separators=seperators\n",
    "    )\n",
    "    chunks = splitter.split_documents(documents)\n",
    "    for chunk in chunks:\n",
    "        chunk.metadata['uuid'] = str(uuid4())\n",
    "\n",
    "    # Create the cache folder on first use; the context manager flushes and closes the file.\n",
    "    cache_dir = os.path.dirname(filepath)\n",
    "    if cache_dir:\n",
    "        os.makedirs(cache_dir, exist_ok=True)\n",
    "    with open(filepath, 'wb') as cache_file:\n",
    "        pickle.dump(chunks, cache_file)\n",
    "\n",
    "    return chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "aa25540d-0504-4ae7-9804-9e3862b132d5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:18:00.231052Z",
     "iopub.status.busy": "2024-09-02T07:18:00.230885Z",
     "iopub.status.idle": "2024-09-02T07:18:00.244672Z",
     "shell.execute_reply": "2024-09-02T07:18:00.244322Z",
     "shell.execute_reply.started": "2024-09-02T07:18:00.231041Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "found cache, restoring...\n"
     ]
    }
   ],
   "source": [
    "# When the pickle already exists, split_docs returns it unchanged and the chunk_size / chunk_overlap\n",
    "# given here are not applied (pass force_split=True to re-split).\n",
    "splitted_docs = split_docs(\n",
    "    documents,\n",
    "    os.path.join(preprocess_output_dir, 'split_docs.pkl'),\n",
    "    chunk_size=500,\n",
    "    chunk_overlap=50,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "220dbc3a-fceb-4e49-a3f1-01e16660b2a6",
   "metadata": {},
   "source": [
    "# 检索"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8598a11c-25d8-4af1-a98b-06a8c394e261",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:18:00.245232Z",
     "iopub.status.busy": "2024-09-02T07:18:00.245110Z",
     "iopub.status.idle": "2024-09-02T07:18:00.260296Z",
     "shell.execute_reply": "2024-09-02T07:18:00.259849Z",
     "shell.execute_reply.started": "2024-09-02T07:18:00.245220Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device: cuda\n"
     ]
    }
   ],
   "source": [
    "from langchain.embeddings import HuggingFaceBgeEmbeddings\n",
    "import torch\n",
    "\n",
    "# Use the GPU when torch reports CUDA as available, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'\n",
    "print('device: ' + device)\n",
    "\n",
    "def get_embeddings(model_path):\n",
    "    \"\"\"Build a ``HuggingFaceBgeEmbeddings`` instance for the model named by ``model_path``.\n",
    "\n",
    "    The model is placed on the notebook-level ``device``, embeddings are requested\n",
    "    normalized, and a fixed Chinese instruction string is passed as ``query_instruction``.\n",
    "    \"\"\"\n",
    "    model_options = {'device': device}\n",
    "    encode_options = {'normalize_embeddings': True}\n",
    "    return HuggingFaceBgeEmbeddings(\n",
    "        model_name=model_path,\n",
    "        model_kwargs=model_options,\n",
    "        encode_kwargs=encode_options,\n",
    "        query_instruction='为这个句子生成表示以用于检索相关文章：'\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "663ef1a4-5866-4f6b-8d9d-4724f62142cb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:18:00.260939Z",
     "iopub.status.busy": "2024-09-02T07:18:00.260765Z",
     "iopub.status.idle": "2024-09-02T07:18:15.424640Z",
     "shell.execute_reply": "2024-09-02T07:18:15.424157Z",
     "shell.execute_reply.started": "2024-09-02T07:18:00.260927Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 0.495 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    }
   ],
   "source": [
    "import jieba\n",
    "import shutil\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "from langchain_community.vectorstores import Chroma\n",
    "from langchain.retrievers import BM25Retriever, EnsembleRetriever\n",
    "\n",
    "# If the model has already been downloaded, this can be replaced with a local path.\n",
    "model_path = 'stevenluo/bge-large-zh-v1.5-ft-v4'\n",
    "embeddings = get_embeddings(model_path)\n",
    "\n",
    "# Folder where the Chroma index for this experiment is persisted.\n",
    "persist_directory = os.path.join(expr_dir, 'chroma', 'bge')\n",
    "# Remove any previous index first (a missing folder is ignored) so the store is rebuilt from scratch.\n",
    "shutil.rmtree(persist_directory, ignore_errors=True)\n",
    "# Embed the chunks and build the vector store.\n",
    "vector_db = Chroma.from_documents(\n",
    "    splitted_docs,\n",
    "    embedding=embeddings,\n",
    "    persist_directory=persist_directory\n",
    ")\n",
    "# Keyword retriever over the same chunks; text is tokenized with jieba before BM25 indexing.\n",
    "chz_cut_bm25_retriever = BM25Retriever.from_documents(splitted_docs, preprocess_func=lambda text: list(jieba.cut(text)))\n",
    "\n",
    "def build_get_ensemble_retriver_fn(weights=[0.5, 0.5]):\n",
    "    \"\"\"Return a factory ``k -> EnsembleRetriever`` combining vector search and BM25.\n",
    "\n",
    "    ``weights`` is forwarded to ``EnsembleRetriever`` for the [vector, BM25] retrievers.\n",
    "    Only the vector retriever is configured with ``k``.\n",
    "    NOTE(review): the shared BM25 retriever keeps its own result count - confirm this is intended.\n",
    "    \"\"\"\n",
    "    def get_ensemble_retriever(k):\n",
    "        dense_retriever = vector_db.as_retriever(search_kwargs={'k': k})\n",
    "        return EnsembleRetriever(\n",
    "            retrievers=[dense_retriever, chz_cut_bm25_retriever],\n",
    "            weights=weights\n",
    "        )\n",
    "\n",
    "    return get_ensemble_retriever"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "566c6f3c-5777-4aa9-bc60-a3ee23050506",
   "metadata": {},
   "source": [
    "## 计算检索准确率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b03e3382-39e9-4932-a265-69b811041629",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:18:15.425344Z",
     "iopub.status.busy": "2024-09-02T07:18:15.425177Z",
     "iopub.status.idle": "2024-09-02T07:18:15.428475Z",
     "shell.execute_reply": "2024-09-02T07:18:15.428164Z",
     "shell.execute_reply.started": "2024-09-02T07:18:15.425332Z"
    }
   },
   "outputs": [],
   "source": [
    "# Keep only rows that belong to the test split and whose qa_type is 'detailed'.\n",
    "is_test_split = qa_df['dataset'] == 'test'\n",
    "is_detailed_qa = qa_df['qa_type'] == 'detailed'\n",
    "test_df = qa_df[is_test_split & is_detailed_qa]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "32c3ad14-b217-44aa-bdb9-909b9d559668",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:18:15.429226Z",
     "iopub.status.busy": "2024-09-02T07:18:15.428903Z",
     "iopub.status.idle": "2024-09-02T07:18:15.442418Z",
     "shell.execute_reply": "2024-09-02T07:18:15.441959Z",
     "shell.execute_reply.started": "2024-09-02T07:18:15.429201Z"
    }
   },
   "outputs": [],
   "source": [
    "def get_hit_stat_df(get_retriever_fn, top_k_arr=list(range(1, 9))):\n",
    "    \"\"\"Record, for each top-k, whether the expected chunk is retrieved for every test question.\n",
    "\n",
    "    Args:\n",
    "        get_retriever_fn: Callable taking ``k`` and returning a retriever.\n",
    "        top_k_arr: Top-k values to evaluate.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with one row per (question, top_k) holding ``question``, ``top_k``,\n",
    "        ``hit`` (1 if the row's ``uuid`` is among the kept chunks, else 0) and\n",
    "        ``retrieved_chunks`` (number of chunks kept).\n",
    "\n",
    "    Questions and expected uuids are read from the notebook-level ``test_df``.\n",
    "    \"\"\"\n",
    "    hit_stat_data = []\n",
    "\n",
    "    for k in tqdm(top_k_arr):\n",
    "        retriever = get_retriever_fn(k)\n",
    "        for _, row in test_df.iterrows():\n",
    "            question = row['question']\n",
    "            true_uuid = row['uuid']\n",
    "            # `invoke` replaces the deprecated `get_relevant_documents`; the slice keeps\n",
    "            # at most k documents in case the retriever returns more.\n",
    "            chunks = retriever.invoke(question)[:k]\n",
    "            retrieved_uuids = {doc.metadata['uuid'] for doc in chunks}\n",
    "\n",
    "            hit_stat_data.append({\n",
    "                'question': question,\n",
    "                'top_k': k,\n",
    "                'hit': int(true_uuid in retrieved_uuids),\n",
    "                'retrieved_chunks': len(chunks)\n",
    "            })\n",
    "    return pd.DataFrame(hit_stat_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a63797c7-4151-4f55-8e5d-080c34265393",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:18:15.443033Z",
     "iopub.status.busy": "2024-09-02T07:18:15.442885Z",
     "iopub.status.idle": "2024-09-02T07:18:33.502561Z",
     "shell.execute_reply": "2024-09-02T07:18:33.502170Z",
     "shell.execute_reply.started": "2024-09-02T07:18:15.443021Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "31ce87141ebd480385a8424fa08bfeb0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/lib/python3.10/site-packages/langchain_core/_api/deprecation.py:139: LangChainDeprecationWarning: The method `BaseRetriever.get_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 0.3.0. Use invoke instead.\n",
      "  warn_deprecated(\n"
     ]
    }
   ],
   "source": [
    "# Retriever-only evaluation: vector search and BM25 weighted equally.\n",
    "get_equal_weight_retriever = build_get_ensemble_retriver_fn(weights=[0.5, 0.5])\n",
    "retriever_only_hit_stat_df = get_hit_stat_df(get_equal_weight_retriever)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d4890789-a44c-41de-b17f-0ff505788494",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:18:33.504382Z",
     "iopub.status.busy": "2024-09-02T07:18:33.504251Z",
     "iopub.status.idle": "2024-09-02T07:18:33.510738Z",
     "shell.execute_reply": "2024-09-02T07:18:33.510425Z",
     "shell.execute_reply.started": "2024-09-02T07:18:33.504370Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>top_k</th>\n",
       "      <th>hit_rate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0.548387</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>0.752688</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>0.870968</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>0.892473</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>0.913978</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>6</td>\n",
       "      <td>0.935484</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>7</td>\n",
       "      <td>0.946237</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>8</td>\n",
       "      <td>0.967742</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   top_k  hit_rate\n",
       "0      1  0.548387\n",
       "1      2  0.752688\n",
       "2      3  0.870968\n",
       "3      4  0.892473\n",
       "4      5  0.913978\n",
       "5      6  0.935484\n",
       "6      7  0.946237\n",
       "7      8  0.967742"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean of the 0/1 `hit` flag per top_k, i.e. the hit rate at each cutoff.\n",
    "(\n",
    "    retriever_only_hit_stat_df\n",
    "    .groupby(['top_k'])['hit']\n",
    "    .mean()\n",
    "    .reset_index()\n",
    "    .rename(columns={'hit': 'hit_rate'})\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b0b086d1-6cec-4743-8df6-2ab3b1593689",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:18:33.511331Z",
     "iopub.status.busy": "2024-09-02T07:18:33.511208Z",
     "iopub.status.idle": "2024-09-02T07:18:33.853576Z",
     "shell.execute_reply": "2024-09-02T07:18:33.853065Z",
     "shell.execute_reply.started": "2024-09-02T07:18:33.511319Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Axes: xlabel='top_k', ylabel='hit'>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAGxCAYAAACeKZf2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAiiUlEQVR4nO3dfVSUdf7/8deAApqKmXIjomg34i2YJIvmdkex5rE4nW1Zc5XQ7FRQ6PxqFW8gM8XaJD1lkvf+djN12yx3Ncr4heZKoSitbqmZGqwK6umbKCmsM/P7o9Ps8hUNFOYaPjwf58w5zcV1Oe/PelyfXtc1MzaXy+USAACAIXysHgAAAKAxETcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjNLK6gE8zel06vjx42rfvr1sNpvV4wAAgHpwuVw6e/asunbtKh+fK5+baXFxc/z4cYWHh1s9BgAAuAplZWXq1q3bFfexNG62bdumP/zhDyouLtaJEye0YcMGJSYmXvGYgoIC2e12/fOf/1R4eLhmzJihRx99tN6v2b59e0k//o/ToUOHa5geAAB4SmVlpcLDw91/j1+JpXFTVVWlqKgojR8/Xg899NDP7n/kyBGNHDlSTzzxhN566y3l5+frscceU2hoqBISEur1mj9diurQoQNxAwBAM1OfW0osjZsRI0ZoxIgR9d4/NzdXPXv21Pz58yVJffr00fbt2/Xqq6/WO24AAIDZmtW7pQoLCxUfH19rW0JCggoLCy2aCAAAeJtmdUNxeXm5goODa20LDg5WZWWlzp8/rzZt2lxyTHV1taqrq93PKysrm3xOAABgnWZ15uZqZGdnKzAw0P3gnVIAAJitWcVNSEiIKioqam2rqKhQhw4d6jxrI0kZGRk6c+aM+1FWVuaJUQEAgEWa1WWpuLg4bd68uda2LVu2KC4u7rLH+Pv7y9/fv6lHAwAAXsLSMzfnzp1TSUmJSkpKJP34Vu+SkhKVlpZK+vGsy7hx49z7P/HEEzp8+LB+//vfa//+/XrjjTe0fv16TZ482YrxAQCAF7I0bnbt2qVBgwZp0KBBkiS73a5BgwYpMzNTknTixAl36EhSz549tWnTJm3ZskVRUVGaP3++li1bxtvAAQCAm83lcrmsHsKTKisrFRgYqDNnzvAhfgAANBMN+fu7Wd1QDAAA8HOIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABglGb1CcUAAODynn/+eatHaBTXug7O3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKLwVHABgnK/m/D+rR2gUfabfbfUIzRJnbgAAgFGIGwAAYBTiBgAAGIV7bgDAYHN+92urR2gU0//0jtUjoBnhzA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAo/A5NwBahNf/z1+tHqFRpM0fZfUIgNfjzA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAo/A5N0ALs/WXd1g9QqO4Y9tWq0cA4KU4cwMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKHzODVqsYa8Ns3qERvH3p/9u9QgA4FU4cwMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwiuVxs2jRIkVERCggIECxsbEqKiq64v4LFixQ79691aZNG4WHh2vy5Mm6cOGCh6YFAADeztK4Wbdunex2u7KysrR7925FRUUpISFBJ0+erHP/NWvWaOrUqcrKytJXX32l5cuXa926dZo2bZqHJwcAAN7K0rjJycnRxIkTlZKSor59+yo3N1dt
27bVihUr6tx/x44dGjZsmB555BFFRETovvvu0+jRo3/2bA8AAGg5LIubmpoaFRcXKz4+/j/D+PgoPj5ehYWFdR4zdOhQFRcXu2Pm8OHD2rx5s+6//36PzAwAALxfK6te+PTp03I4HAoODq61PTg4WPv376/zmEceeUSnT5/W7bffLpfLpYsXL+qJJ5644mWp6upqVVdXu59XVlY2zgIAAIBXsvyG4oYoKCjQ3Llz9cYbb2j37t169913tWnTJs2ePfuyx2RnZyswMND9CA8P9+DEAADA0yw7c9O5c2f5+vqqoqKi1vaKigqFhITUeczMmTM1duxYPfbYY5KkAQMGqKqqSo8//rimT58uH59LWy0jI0N2u939vLKyksABAMBglp258fPz0+DBg5Wfn+/e5nQ6lZ+fr7i4uDqP+eGHHy4JGF9fX0mSy+Wq8xh/f3916NCh1gMAAJjLsjM3kmS325WcnKyYmBgNGTJECxYsUFVVlVJSUiRJ48aNU1hYmLKzsyVJo0aNUk5OjgYNGqTY2FgdOnRIM2fO1KhRo9yRAwAAWjZL4yYpKUmnTp1SZmamysvLFR0drby8PPdNxqWlpbXO1MyYMUM2m00zZszQsWPH1KVLF40aNUpz5syxagkAAMDLWBo3kpSWlqa0tLQ6f1ZQUFDreatWrZSVlaWsrCwPTAYAAJqjZvVuKQAAgJ9D3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoln+3FKxX+sIAq0doFN0z91o9AgDAC3DmBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARrE8bhYtWqSIiAgFBAQoNjZWRUVFV9z/+++/V2pqqkJDQ+Xv769bbrlFmzdv9tC0AADA27Wy8sXXrVsnu92u3NxcxcbGasGCBUpISNCBAwcUFBR0yf41NTW69957FRQUpHfeeUdhYWH69ttv1bFjR88PDwAAvJKlcZOTk6OJEycqJSVFkpSbm6tNmzZpxYoVmjp16iX7r1ixQt9995127Nih1q1bS5IiIiI8OTIAAPByll2WqqmpUXFxseLj4/8zjI+P4uPjVVhYWOcxGzduVFxcnFJTUxUcHKz+/ftr7ty5cjgcl32d6upqVVZW1noAAABzWRY3p0+flsPhUHBwcK3twcHBKi8vr/OYw4cP65133pHD4dDmzZs1c+ZMzZ8/Xy+++OJlXyc7O1uBgYHuR3h4eKOuAwAAeBfLbyhuCKfTqaCgIC1ZskSDBw9WUlKSpk+frtzc3Msek5GRoTNnzrgfZWVlHpwYAAB4mmX33HTu3Fm+vr6qqKiotb2iokIhISF1HhMaGqrWrVvL19fXva1Pnz4qLy9XTU2N/Pz8LjnG399f/v7+jTs8AADwWpadufHz89PgwYOVn5/v3uZ0OpWfn6+4uLg6jxk2bJgOHTokp9Pp3nbw4EGFhobWGTYAAKDlsfSylN1u19KlS7V69Wp99dVXevLJJ1VVVeV+99S4ceOUkZHh3v/JJ5/Ud999p/T0dB08eFCbNm3S3LlzlZqaatUSAACAl7H0reBJSUk6deqUMjMzVV5erujoaOXl5blvMi4tLZWPz3/6Kzw8XB9++KEmT56sgQMHKiws
TOnp6ZoyZYpVSwAAAF7G0riRpLS0NKWlpdX5s4KCgku2xcXF6bPPPmviqQAAQHPVrN4tBQAA8HOIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEax/Iszvcng5/6v1SM0iuI/jLN6BAAALMOZGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRripu7r77bn3//feXbK+srNTdd999rTMBAABctauKm4KCAtXU1Fyy/cKFC/r000+veSgAAICr1aCvX/jHP/7h/u8vv/xS5eXl7ucOh0N5eXkKCwtrvOkAAAAaqEFxEx0dLZvNJpvNVuflpzZt2ui1115rtOEAAAAaqkFxc+TIEblcLvXq1UtFRUXq0qWL+2d+fn4KCgqSr69vow8JAABQXw2Kmx49ekiSnE5nkwwDAABwreodNxs3btSIESPUunVrbdy48Yr7PvDAA9c8GAAAwNWod9wkJiaqvLxcQUFBSkxMvOx+NptNDoejMWYDAABosHrHzX9fiuKyFAAA8FYNuufmv+Xn5ys/P18nT56sFTs2m03Lly9vlOEAAAAa6qriZtasWXrhhRcUExOj0NBQ2Wy2xp4LAADgqlxV3OTm5mrVqlUaO3ZsY88DAABwTa7q6xdqamo0dOjQxp4FAADgml1V3Dz22GNas2ZNY88CAABwzep9Wcput7v/2+l0asmSJfr44481cOBAtW7duta+OTk5jTchAABAA9Q7bvbs2VPreXR0tCRp3759tbZzczEAALBSvePmk08+aco5AAAAGsVV3XMDAADgrYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARvGKuFm0aJEiIiIUEBCg2NhYFRUV1eu4tWvXymazKTExsWkHBAAAzYblcbNu3TrZ7XZlZWVp9+7dioqKUkJCgk6ePHnF444ePapnn31Ww4cP99CkAACgObA8bnJycjRx4kSlpKSob9++ys3NVdu2bbVixYrLHuNwODRmzBjNmjVLvXr18uC0AADA21kaNzU1NSouLlZ8fLx7m4+Pj+Lj41VYWHjZ41544QUFBQVpwoQJP/sa1dXVqqysrPUAAADmsjRuTp8+LYfDoeDg4Frbg4ODVV5eXucx27dv1/Lly7V06dJ6vUZ2drYCAwPdj/Dw8GueGwAAeC/LL0s1xNmzZzV27FgtXbpUnTt3rtcxGRkZOnPmjPtRVlbWxFMCAAArtbLyxTt37ixfX19VVFTU2l5RUaGQkJBL9v/mm2909OhRjRo1yr3N6XRKklq1aqUDBw7oxhtvrHWMv7+//P39m2B6AADgjSw9c+Pn56fBgwcrPz/fvc3pdCo/P19xcXGX7B8ZGam9e/eqpKTE/XjggQd01113qaSkhEtOAADA2jM3kmS325WcnKyYmBgNGTJECxYsUFVVlVJSUiRJ48aNU1hYmLKzsxUQEKD+/fvXOr5jx46SdMl2AADQMlkeN0lJSTp16pQyMzNVXl6u6Oho5eXluW8yLi0tlY9Ps7o1CAAAWMjyuJGktLQ0paWl1fmzgoKCKx67atWqxh8IAAA0W5wSAQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU
4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARvGKuFm0aJEiIiIUEBCg2NhYFRUVXXbfpUuXavjw4br++ut1/fXXKz4+/or7AwCAlsXyuFm3bp3sdruysrK0e/duRUVFKSEhQSdPnqxz/4KCAo0ePVqffPKJCgsLFR4ervvuu0/Hjh3z8OQAAMAbWR43OTk5mjhxolJSUtS3b1/l5uaqbdu2WrFiRZ37v/XWW3rqqacUHR2tyMhILVu2TE6nU/n5+R6eHAAAeCNL46ampkbFxcWKj493b/Px8VF8fLwKCwvr9Wv88MMP+ve//61OnTo11ZgAAKAZaWXli58+fVoOh0PBwcG1tgcHB2v//v31+jWmTJmirl271gqk/1ZdXa3q6mr388rKyqsfGAAAeD3LL0tdi3nz5mnt2rXasGGDAgIC6twnOztbgYGB7kd4eLiHpwQAAJ5kadx07txZvr6+qqioqLW9oqJCISEhVzz2lVde0bx58/TRRx9p4MCBl90vIyNDZ86ccT/KysoaZXYAAOCdLI0bPz8/DR48uNbNwD/dHBwXF3fZ415++WXNnj1beXl5iomJueJr+Pv7q0OHDrUeAADAXJbecyNJdrtdycnJiomJ0ZAhQ7RgwQJVVVUpJSVFkjRu3DiFhYUpOztbkvTSSy8pMzNTa9asUUREhMrLyyVJ7dq1U7t27SxbBwAA8A6Wx01SUpJOnTqlzMxMlZeXKzo6Wnl5ee6bjEtLS+Xj858TTIsXL1ZNTY1+/etf1/p1srKy9Pzzz3tydAAA4IUsjxtJSktLU1paWp0/KygoqPX86NGjTT8QAABotpr1u6UAAAD+N+IGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYxSviZtGiRYqIiFBAQIBiY2NVVFR0xf3//Oc/KzIyUgEBARowYIA2b97soUkBAIC3szxu1q1bJ7vdrqysLO3evVtRUVFKSEjQyZMn69x/x44dGj16tCZMmKA9e/YoMTFRiYmJ2rdvn4cnBwAA3sjyuMnJydHEiROVkpKivn37Kjc3V23bttWKFSvq3H/hwoX61a9+peeee059+vTR7Nmzdeutt+r111/38OQAAMAbWRo3NTU1Ki4uVnx8vHubj4+P4uPjVVhYWOcxhYWFtfaXpISEhMvuDwAAWpZWVr746dOn5XA4FBwcXGt7cHCw9u/fX+cx5eXlde5fXl5e5/7V1dWqrq52Pz9z5owkqbKy8pJ9HdXnGzS/t6prbVdy9oKjiSbxrIau++L5i000iWc1dN1VF1vmus9X/9BEk3hWQ9d94d//bqJJPKuh6z53oaqJJvGshq77v/++a87qWvdP21wu188eb2nceEJ2drZmzZp1yfbw8HAL
pvGMwNeesHoEa2QHWj2BJQKntMx1K7Blrvv3i6yewBovrm+Zv9960eoBrDFv3rzL/uzs2bMK/Jk//5bGTefOneXr66uKiopa2ysqKhQSElLnMSEhIQ3aPyMjQ3a73f3c6XTqu+++0w033CCbzXaNK2iYyspKhYeHq6ysTB06dPDoa1uJdbPuloB1s+6WwMp1u1wunT17Vl27dv3ZfS2NGz8/Pw0ePFj5+flKTEyU9GN85OfnKy0trc5j4uLilJ+fr0mTJrm3bdmyRXFxcXXu7+/vL39//1rbOnbs2BjjX7UOHTq0qD8MP2HdLQvrbllYd8ti1bp/7ozNTyy/LGW325WcnKyYmBgNGTJECxYsUFVVlVJSUiRJ48aNU1hYmLKzsyVJ6enpuuOOOzR//nyNHDlSa9eu1a5du7RkyRIrlwEAALyE5XGTlJSkU6dOKTMzU+Xl5YqOjlZeXp77puHS0lL5+PznTV1Dhw7VmjVrNGPGDE2bNk0333yz3nvvPfXv39+qJQAAAC9iedxIUlpa2mUvQxUUFFyy7eGHH9bDDz/cxFM1Pn9/f2VlZV1ymcx0rJt1twSsm3W3BM1l3TZXfd5TBQAA0ExY/gnFAAAAjYm4AQAARiFuAACAUYgbD9i2bZtGjRqlrl27ymaz6b333rN6JI/Izs7Wbbfdpvbt2ysoKEiJiYk6cOCA1WM1ucWLF2vgwIHuz4GIi4vTBx98YPVYHjdv3jzZbLZan0lloueff142m63WIzIy0uqxPOLYsWP63e9+pxtuuEFt2rTRgAEDtGvXLqvHalIRERGX/H7bbDalpqZaPVqTcjgcmjlzpnr27Kk2bdroxhtv1OzZs+v1VQhW8Ip3S5muqqpKUVFRGj9+vB566CGrx/GYrVu3KjU1VbfddpsuXryoadOm6b777tOXX36p6667zurxmky3bt00b9483XzzzXK5XFq9erUefPBB7dmzR/369bN6PI/YuXOn3nzzTQ0cONDqUTyiX79++vjjj93PW7Uy//9a/+d//kfDhg3TXXfdpQ8++EBdunTR119/reuvv97q0ZrUzp075XD85/v49u3bp3vvvbdZvoO3IV566SUtXrxYq1evVr9+/bRr1y6lpKQoMDBQzzzzjNXjXcL8P4FeYMSIERoxYoTVY3hcXl5ereerVq1SUFCQiouL9ctf/tKiqZreqFGjaj2fM2eOFi9erM8++6xFxM25c+c0ZswYLV26VC++2DK+GKdVq1aX/QoYU7300ksKDw/XypUr3dt69uxp4USe0aVLl1rP582bpxtvvFF33HGHRRN5xo4dO/Tggw9q5MiRkn48g/X222+rqKjI4snqxmUpeMxP38jeqVMniyfxHIfDobVr16qqquqyXxFimtTUVI0cOVLx8fFWj+IxX3/9tbp27apevXppzJgxKi0ttXqkJrdx40bFxMTo4YcfVlBQkAYNGqSlS5daPZZH1dTU6E9/+pPGjx/v8e8q9LShQ4cqPz9fBw8elCR98cUX2r59u9f+w50zN/AIp9OpSZMmadiwYS3i06T37t2ruLg4XbhwQe3atdOGDRvUt29fq8dqcmvXrtXu3bu1c+dOq0fxmNjYWK1atUq9e/fWiRMnNGvWLA0fPlz79u1T+/btrR6vyRw+fFiLFy+W3W7XtGnTtHPnTj3zzDPy8/NTcnKy1eN5xHvvvafvv/9ejz76qNWjNLmpU6eqsrJSkZGR8vX1lcPh0Jw5czRmzBirR6sTcQOPSE1N1b59+7R9+3arR/GI3r17q6SkRGfOnNE777yj5ORkbd261ejAKSsrU3p6urZs2aKAgACrx/GY//6X68CBAxUbG6sePXpo/fr1mjBhgoWTNS2n06mYmBjNnTtXkjRo0CDt27dPubm5LSZuli9frhEjRtTrW6qbu/Xr1+utt97SmjVr1K9fP5WUlGjSpEnq2rWrV/5+Ezdocmlpafrb3/6mbdu2qVu3blaP4xF+fn666aabJEmD
Bw/Wzp07tXDhQr355psWT9Z0iouLdfLkSd16663ubQ6HQ9u2bdPrr7+u6upq+fr6WjihZ3Ts2FG33HKLDh06ZPUoTSo0NPSSWO/Tp4/+8pe/WDSRZ3377bf6+OOP9e6771o9ikc899xzmjp1qn77299KkgYMGKBvv/1W2dnZxA1aFpfLpaefflobNmxQQUFBi7jZ8HKcTqeqq6utHqNJ3XPPPdq7d2+tbSkpKYqMjNSUKVNaRNhIP95Q/c0332js2LFWj9Kkhg0bdslHOxw8eFA9evSwaCLPWrlypYKCgtw32Jruhx9+qPUl1pLk6+srp9Np0URXRtx4wLlz52r9K+7IkSMqKSlRp06d1L17dwsna1qpqalas2aN3n//fbVv317l5eWSpMDAQLVp08bi6ZpORkaGRowYoe7du+vs2bNas2aNCgoK9OGHH1o9WpNq3779JfdTXXfddbrhhhuMvs/q2Wef1ahRo9SjRw8dP35cWVlZ8vX11ejRo60erUlNnjxZQ4cO1dy5c/Wb3/xGRUVFWrJkiZYsWWL1aE3O6XRq5cqVSk5ObhFv+5d+fBfonDlz1L17d/Xr10979uxRTk6Oxo8fb/VodXOhyX3yyScuSZc8kpOTrR6tSdW1ZkmulStXWj1akxo/fryrR48eLj8/P1eXLl1c99xzj+ujjz6yeixL3HHHHa709HSrx2hSSUlJrtDQUJefn58rLCzMlZSU5Dp06JDVY3nEX//6V1f//v1d/v7+rsjISNeSJUusHskjPvzwQ5ck14EDB6wexWMqKytd6enpru7du7sCAgJcvXr1ck2fPt1VXV1t9Wh14lvBAQCAUficGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsALVpERIQWLFhg9RgAGhFxA8Br3HnnnZo0aZLVYwBo5ogbAABgFOIGgFd49NFHtXXrVi1cuFA2m002m01Hjx7V1q1bNWTIEPn7+ys0NFRTp07VxYsX3cfdeeedSktLU1pamgIDA9W5c2fNnDlTV/u1ecuWLVPHjh2Vn5/fWEsD4GHEDQCvsHDhQsXFxWnixIk6ceKETpw4odatW+v+++/Xbbfdpi+++EKLFy/W8uXL9eKLL9Y6dvXq1WrVqpWKioq0cOFC5eTkaNmyZQ2e4eWXX9bUqVP10Ucf6Z577mmspQHwsFZWDwAAkhQYGCg/Pz+1bdtWISEhkqTp06crPDxcr7/+umw2myIjI3X8+HFNmTJFmZmZ8vH58d9n4eHhevXVV2Wz2dS7d2/t3btXr776qiZOnFjv158yZYr++Mc/auvWrerXr1+TrBGAZ3DmBoDX+uqrrxQXFyebzebeNmzYMJ07d07/+te/3Nt+8Ytf1NonLi5OX3/9tRwOR71eZ/78+Vq6dKm2b99O2AAGIG4AtHjDhw+Xw+HQ+vXrrR4FQCMgbgB4DT8/v1pnW/r06aPCwsJaNwf//e9/V/v27dWtWzf3ts8//7zWr/PZZ5/p5ptvlq+vb71ed8iQIfrggw80d+5cvfLKK9e4CgBWI24AeI2IiAh9/vnnOnr0qE6fPq2nnnpKZWVlevrpp7V//369//77ysrKkt1ud99vI0mlpaWy2+06cOCA3n77bb322mtKT09v0GsPHTpUmzdv1qxZs/hQP6CZ44ZiAF7j2WefVXJysvr27avz58/ryJEj2rx5s5577jlFRUWpU6dOmjBhgmbMmFHruHHjxun8+fMaMmSIfH19lZ6erscff7zBr3/77bdr06ZNuv/+++Xr66unn366sZYGwINsrqv9MAgA8AJ33nmnoqOjOdsCwI3LUgAAwCjEDQBjffrpp2rXrt1lHwDMxGUpAMY6f/68jh07dtmf33TTTR6cBoCnEDcAAMAoXJYCAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGOX/A9ji3fZw1hkGAAAAAElFTkSu
QmCC\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "# Bar height is the mean of the 0/1 `hit` flag per top_k, i.e. the hit rate.\n",
    "hit_rate_ax = sns.barplot(x='top_k', y='hit', data=retriever_only_hit_stat_df, errorbar=None)\n",
    "# Label the figure so it stands alone; the trailing semicolon hides the return-value repr.\n",
    "hit_rate_ax.set(xlabel='top_k', ylabel='hit rate', title='Retrieval hit rate by top_k (retriever only)');"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6c2e518-ae0c-42e6-b610-9cc68cb68448",
   "metadata": {},
   "source": [
    "## 加Reranker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "4cf953a8-2d3b-4641-9d1e-da60821a95e7",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:18:33.854294Z",
     "iopub.status.busy": "2024-09-02T07:18:33.854132Z",
     "iopub.status.idle": "2024-09-02T07:18:35.694721Z",
     "shell.execute_reply": "2024-09-02T07:18:35.694207Z",
     "shell.execute_reply.started": "2024-09-02T07:18:33.854281Z"
    }
   },
   "outputs": [],
   "source": [
    "from FlagEmbedding import FlagReranker\n",
    "\n",
    "# Model id of the reranker to load.\n",
    "rerank_model_path = 'stevenluo/bge-reranker-base-ft-v1'\n",
    "\n",
    "# Load the reranker with half precision requested.\n",
    "reranker = FlagReranker(\n",
    "    rerank_model_path,\n",
    "    use_fp16=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "05c6c139-1ed8-4f6e-a00f-d08684e30a84",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:18:35.695524Z",
     "iopub.status.busy": "2024-09-02T07:18:35.695349Z",
     "iopub.status.idle": "2024-09-02T07:18:35.698810Z",
     "shell.execute_reply": "2024-09-02T07:18:35.698499Z",
     "shell.execute_reply.started": "2024-09-02T07:18:35.695511Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda')"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show which device the reranker was placed on\n",
    "reranker.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "4938875e-edc0-47b3-992a-0adc8958f9cc",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:18:35.699470Z",
     "iopub.status.busy": "2024-09-02T07:18:35.699323Z",
     "iopub.status.idle": "2024-09-02T07:18:35.712605Z",
     "shell.execute_reply": "2024-09-02T07:18:35.712139Z",
     "shell.execute_reply.started": "2024-09-02T07:18:35.699458Z"
    }
   },
   "outputs": [],
   "source": [
    "def rerank(reranker, query, retrieved_docs, top_k=5, debug=False):\n",
    "    \"\"\"\n",
    "    Re-order retrieved documents by reranker relevance score, highest first.\n",
    "\n",
    "    :param reranker: object exposing compute_score(pairs) for [query, text] pairs (e.g. FlagReranker)\n",
    "    :param query: the question the documents are scored against\n",
    "    :param retrieved_docs: candidate documents, each exposing page_content\n",
    "    :param top_k: number of documents returned when debug is False\n",
    "    :param debug: when True, return all (query, doc, score) triads sorted by score\n",
    "    :return: the top_k documents, or the full sorted triad list when debug is True\n",
    "    \"\"\"\n",
    "    retrieved_docs = list(retrieved_docs)\n",
    "    # Nothing to score: return early instead of sending an empty batch to the model\n",
    "    if not retrieved_docs:\n",
    "        return []\n",
    "\n",
    "    rerank_scores = reranker.compute_score([[query, doc.page_content] for doc in retrieved_docs])\n",
    "    # Some FlagEmbedding versions return a bare float (not a list) when given a single pair\n",
    "    if not hasattr(rerank_scores, '__iter__'):\n",
    "        rerank_scores = [rerank_scores]\n",
    "\n",
    "    triads = sorted(\n",
    "        ((query, doc, score) for doc, score in zip(retrieved_docs, rerank_scores)),\n",
    "        key=lambda triad: triad[-1],\n",
    "        reverse=True,\n",
    "    )\n",
    "    if debug:\n",
    "        return triads\n",
    "    return [doc for _, doc, _ in triads[:top_k]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "a52a7737-8041-4802-bc77-a4c4acd6877c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:18:35.713145Z",
     "iopub.status.busy": "2024-09-02T07:18:35.713031Z",
     "iopub.status.idle": "2024-09-02T07:21:51.317027Z",
     "shell.execute_reply": "2024-09-02T07:21:51.316692Z",
     "shell.execute_reply.started": "2024-09-02T07:18:35.713133Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "334259bce1824e76a16764bdf0dd58b7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/8 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Hit-rate evaluation for the retriever + reranker pipeline, for top_k = 1..8\n",
    "get_retriever_fn = build_get_ensemble_retriver_fn(weights=[0.5, 0.5])\n",
    "# Over-fetch factor: k * retirever_multiplier is passed to the retriever factory and also caps\n",
    "# the candidate list that is handed to the reranker\n",
    "retirever_multiplier = 3\n",
    "\n",
    "top_k_arr = list(range(1, 9))\n",
    "hit_stat_data = []\n",
    "\n",
    "for k in tqdm(top_k_arr):\n",
    "    n_candidates = k * retirever_multiplier\n",
    "    retriever = get_retriever_fn(n_candidates)\n",
    "    for idx, row in test_df.iterrows():\n",
    "        question, true_uuid = row['question'], row['uuid']\n",
    "        chunks = retriever.get_relevant_documents(question)[:n_candidates]\n",
    "        chunks = rerank(reranker, question, chunks, top_k=k)\n",
    "        retrieved_uuids = [doc.metadata['uuid'] for doc in chunks[:k]]\n",
    "\n",
    "        # One record per (question, k); hit is 1 when the source chunk survived reranking\n",
    "        hit_stat_data.append(dict(question=question, top_k=k, hit=int(true_uuid in retrieved_uuids)))\n",
    "\n",
    "with_reranker_hit_stat_df = pd.DataFrame(hit_stat_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "00270548-1b2f-4630-8f60-c523d8aef42b",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:21:51.317679Z",
     "iopub.status.busy": "2024-09-02T07:21:51.317528Z",
     "iopub.status.idle": "2024-09-02T07:21:51.323483Z",
     "shell.execute_reply": "2024-09-02T07:21:51.323036Z",
     "shell.execute_reply.started": "2024-09-02T07:21:51.317667Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>top_k</th>\n",
       "      <th>hit_rate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0.774194</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>0.870968</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>0.935484</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>0.956989</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>0.956989</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>6</td>\n",
       "      <td>0.956989</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>7</td>\n",
       "      <td>0.967742</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>8</td>\n",
       "      <td>0.967742</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   top_k  hit_rate\n",
       "0      1  0.774194\n",
       "1      2  0.870968\n",
       "2      3  0.935484\n",
       "3      4  0.956989\n",
       "4      5  0.956989\n",
       "5      6  0.956989\n",
       "6      7  0.967742\n",
       "7      8  0.967742"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Average the 0/1 hit flag per top_k to get the hit rate with the reranker\n",
    "with_reranker_hit_stat_df.groupby(['top_k'])['hit'].mean().reset_index().rename(columns={'hit': 'hit_rate'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "d3f7752d-ba09-437f-b8c2-d9310ceb3473",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:21:51.324043Z",
     "iopub.status.busy": "2024-09-02T07:21:51.323923Z",
     "iopub.status.idle": "2024-09-02T07:21:51.337465Z",
     "shell.execute_reply": "2024-09-02T07:21:51.337029Z",
     "shell.execute_reply.started": "2024-09-02T07:21:51.324031Z"
    }
   },
   "outputs": [],
   "source": [
    "# Label each result set so the two can be told apart after concatenation.\n",
    "# NOTE: this adds the `reranker` column to both source frames in place.\n",
    "retriever_only_hit_stat_df['reranker'] = 'w/o'\n",
    "with_reranker_hit_stat_df['reranker'] = 'w/'\n",
    "hit_stat_df = pd.concat([retriever_only_hit_stat_df, with_reranker_hit_stat_df])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "d2726273-f8ba-4bc6-9cae-933d376b3bf4",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:21:51.338152Z",
     "iopub.status.busy": "2024-09-02T07:21:51.337986Z",
     "iopub.status.idle": "2024-09-02T07:21:51.452144Z",
     "shell.execute_reply": "2024-09-02T07:21:51.451780Z",
     "shell.execute_reply.started": "2024-09-02T07:21:51.338139Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Axes: xlabel='top_k', ylabel='hit'>"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAGxCAYAAACeKZf2AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/TGe4hAAAACXBIWXMAAA9hAAAPYQGoP6dpAAArxElEQVR4nO3df1RUdf7H8deA/PAniAiioZiaqKn4I/2ia7JGsebX1tPZct1Swx+dSlqNk6n5K0tFK5FKV9YfpLW66reyLE0raijLUhH8apuapuHXBPJriaCCDfP9w+NsfEFDZeaOH56Pc+Yc53Iv932ns+vTe+/M2JxOp1MAAACG8LF6AAAAgJpE3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwSh2rB/C08vJy/fDDD2rYsKFsNpvV4wAAgGpwOp06c+aMmjdvLh+fK5+bqXVx88MPPygyMtLqMQAAwDU4duyYbrrppiuuY2ncfPrpp3rhhReUnZ2tEydOaMOGDRoyZMgVt7Hb7UpOTtbXX3+tyMhITZs2TQ899FC199mwYUNJF1+cRo0aXcf0AADAU4qKihQZGen6e/xKLI2bkpISde3aVaNGjdK99977m+sfOXJEgwYN0iOPPKLVq1crMzNTY8aMUUREhBISEqq1z0uXoho1akTcAABwg6nOLSWWxs3AgQM1cODAaq+fnp6u1q1ba8GCBZKkDh06aNu2bVq4cGG14wYAAJjthnq31Pbt2xUfH19hWUJCgrZv327RRAAAwNvcUDcU5+fnKzw8vMKy8PBwFRUV6dy5c6pbt26lbUpLS1VaWup6XlRU5PY5AQCAdW6ouLkWKSkpmjVr1lVv53A4dOHCBTdMdOPz8/OTr6+v1WMAAFClGypumjVrpoKCggrLCgoK1KhRoyrP2kjSlClTlJyc7Hp+6W7ry3E6ncrPz9fPP/9cIzObKjg4WM2aNeOzggAAXueGipvY2Fht3ry5wrIPP/xQsbGxl90mICBAAQEB1d7HpbAJCwtTvXr1+Mv7/3E6nTp79qwKCwslSRERERZPBABARZbGTXFxsQ4dOuR6fuTIEeXm5iokJEQtW7bUlClTdPz4cb322muSpEceeUSLFi3SU089pVGjRunjjz/W+vXrtWnTphqZx+FwuMKmSZMmNfI7TXTpLFlhYaHCwsK4RAUA8CqWvltq165d6tatm7p16yZJSk5OVrdu3TRjxgxJ0okTJ5SXl+dav3Xr1tq0aZM+/PBDde3aVQsWLNDy5ctr7G3gl+6xqVevXo38PpNdeo24LwkA4G0sPXMTFxcnp9N52Z+vXLmyym1ycnLcOFX1PiCotuM1AgB4qxvqc24AAAB+C3FjgIceeug3v5MLAIDagrgBAABGIW5qQFlZmUe28RRvng0AgN9C3FyDuLg4JSUlacKECQoNDVVCQoL27dungQMHqkGDBgoPD9fw4cN18uTJK24jSampqercubPq16+vyMhIPfbYYyouLnZtt3LlSgUHB2vr1q3q0KGDGjRooD/84Q86ceLEZefbuXOnmjZtqvnz50uSfv75Z40ZM0ZNmzZVo0aNNGDAAO3Zs8e1/jPPPKOYmBgtX75crVu3VmBgYE2/ZAAAeAxxc41WrVolf39/ff7555o3b54GDBigbt26adeuXdqyZYsKCgp0//33X3ab9PR0SZKPj49efvllff3111q1apU+/vhjPfXUUxW2O3v2rF588UW9/vrr+vTTT5WXl6cnn3yyyrk+/vhj3XnnnZozZ44mTZokSbrvvvtUWFio999/X9nZ2erevbvuuOMOnTp1yrXdoUOH9Oabb+qtt95Sbm5uDb5SAAB41g31CcXepF27dnr++eclSbNnz1a3bt00d+5c188zMjIUGRmpgwcP6pZbbqm0zSUTJkxw/TkqKkqzZ8/WI488or/97W+u5RcuXFB6erratGkjSUpKStKzzz5b
aaYNGzZoxIgRWr58uYYOHSpJ2rZtm3bs2KHCwkLXJzW/+OKLevvtt/XGG2/o4YcflnTxUtRrr72mpk2bXu9LAwC1Qt6znd2+j5Yz9l7V+j0mvuamSf5tQ8MX3L6Pqz3u/4+4uUY9evRw/XnPnj365JNP1KBBg0rrHT582BU3v97mko8++kgpKSnav3+/ioqK9Msvv+j8+fM6e/as64Py6tWr5wob6eJXHlz6+oNLvvrqK7333nt64403Krxzas+ePSouLq70icvnzp3T4cOHXc9btWpF2AAAjEDcXKP69eu7/lxcXKzBgwe77nH5tV9/99Kvt5Gko0eP6j//8z/16KOPas6cOQoJCdG2bds0evRolZWVueLGz8+vwnY2m63Shx+2adNGTZo0UUZGhgYNGuTapri4WBEREbLb7ZVmCw4OvuxsAADcqIibGtC9e3e9+eabioqKUp061X9Js7OzVV5ergULFsjH5+LtT+vXr7+mGUJDQ/XWW28pLi5O999/v9avXy8/Pz91795d+fn5qlOnjqKioq7pdwNXyxtP13sCx+0+3njc8F7cUFwDxo0bp1OnTmnYsGHauXOnDh8+rK1btyoxMVEOh+Oy27Vt21YXLlzQK6+8ou+++06vv/6660bjaxEWFqaPP/5Y+/fv17Bhw/TLL78oPj5esbGxGjJkiD744AMdPXpUX3zxhaZOnapdu3Zd874AAPBWxE0NaN68uT7//HM5HA7ddddd6ty5syZMmKDg4GDXGZmqdO3aVampqZo/f75uvfVWrV69WikpKdc1S7NmzfTxxx9r7969euCBB1ReXq7Nmzfr9ttvV2Jiom655Rb9+c9/1vfff6/w8PDr2hcAAN7I5rzSN1caqKioSEFBQTp9+rQaNWpU4Wfnz5/XkSNH+KyXauC1wpXU1ssUHLf7cNzVY/K7pa709/f/xz03AADjeOYvebfvAteIy1IAAMAoxA0AADAKcQMAAIzCPTeAG3njDYeoXbj3BLURZ24AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFF4t1Q1eeIdB7+W/cIIt/7+rKwsPfjggzp27Jhb9wMAgKdx5qaWeueddzR48GCrxwAAoMZx5sYA7733nh588EH97//+r3x9fZWbm6tu3bpp0qRJmjdvniRpzJgxOn/+vP7xj39IkjZu3KhFixZJkkpLSzVx4kStXbtWRUVF6tmzpxYuXKjbbrvNsmMCahqf9wLUHpy5MUC/fv105swZ5eTkSLp4ySk0NFR2u921TlZWluLi4iRJX3/9tQoLCzVgwABJ0lNPPaU333xTq1at0u7du9W2bVslJCTo1KlTnj4UAACuG3FjgKCgIMXExLhixm6364knnlBOTo6Ki4t1/PhxHTp0SP3795d08ZJUQkKC/P39VVJSoiVLluiFF17QwIED1bFjRy1btkx169bVihUrLDwqAACuDXFjiP79+8tut8vpdOqzzz7Tvffeqw4dOmjbtm3KyspS8+bN1a5dO0kX4+aee+6RJB0+fFgXLlxQ3759Xb/Lz89PvXr10jfffGPJsQAAcD2458YQcXFxysjI0J49e+Tn56fo6GjFxcXJbrfrp59+cp21OXHihHJycjRo0CCLJwYAwD04c2OIS/fdLFy40BUyl+LGbre77rd599131adPH4WEhEiS2rRpI39/f33++eeu33XhwgXt3LlTHTt29PhxAABwvThzY4jGjRurS5cuWr16tetdULfffrvuv/9+XbhwwRU8GzdudF2SkqT69evr0Ucf1cSJExUSEqKWLVvq+eef19mzZzV69GhLjgUAgOtB3FSTuz9Uryb0799fubm5rrM0ISEh6tixowoKCtS+fXuVlJQoMzNTaWlpFbabN2+eysvLNXz4cJ05c0Y9e/bU1q1b1bhxY88fBAAA14m4MUhaWlqlcMnNzXX9eevWrWrdurXatm1bYZ3AwEC9/PLLevnllz0wJazG570AMB333NQiDRo0
0Pz5860eAwAAt+LMTS1y1113WT0CAABux5kbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU3i0Fj8h7trPb99Fyxt6rWp/PewEAM3HmBgAAGIW4AQAARuGyVDV54rLKr13tJZarlZWVpQcffFDHjh1z634AAPA0ztzUUu+8844GDx5s9RgAANQ44sYA7733noKDg+VwOCRd/LJMm82myZMnu9YZM2aMHnzwQdfzjRs36p577vH4rAAAuBtxY4B+/frpzJkzysnJkXTxklNoaKjsdrtrnaysLMXFxUmSvv76axUWFmrAgAEWTAsAgHsRNwYICgpSTEyMK2bsdrueeOIJ5eTkqLi4WMePH9ehQ4fUv39/SRcvSSUkJMjf39/CqQEAcA/ixhD9+/eX3W6X0+nUZ599pnvvvVcdOnTQtm3blJWVpebNm6tdu3aSLsYNl6QAAKbi3VKGiIuLU0ZGhvbs2SM/Pz9FR0crLi5OdrtdP/30k+uszYkTJ5STk6NBgwZZPDEAAO7BmRtDXLrvZuHCha6QuRQ3drvddb/Nu+++qz59+igkJMTCaQEAcB/ixhCNGzdWly5dtHr1alfI3H777dq9e7cOHjzoCh7eJQUAMB2XparJ3R+qVxP69++v3NxcV9yEhISoY8eOKigoUPv27VVSUqLMzEylpaVZOicAAO7EmRuDpKWlyel0Kjo62rUsNzdXJ06ckCRt3bpVrVu3Vtu2ba0aEQAAtyNuapEGDRpo/vz5Vo8BAIBbcVmqFrnrrrusHgEAALfjzA0AADAKcQMAAIxC3FTB6XRaPYLX4zUCAHgr4uZX/Pz8JElnz561eBLvd+k1uvSaAQDgLbih+Fd8fX0VHByswsJCSVK9evVks9ksnsq7OJ1OnT17VoWFhQoODpavr6/VIwEAUIHlcbN48WK98MILys/PV9euXfXKK6+oV69el10/LS1NS5YsUV5enkJDQ/WnP/1JKSkpCgwMrJF5mjVrJkmuwEHVgoODXa8VAADexNK4WbdunZKTk5Wenq7evXsrLS1NCQkJOnDggMLCwiqtv2bNGk2ePFkZGRnq06ePDh48qIceekg2m02pqak1MpPNZlNERITCwsJ04cKFGvmdpvHz8+OMDQDAa1kaN6mpqRo7dqwSExMlSenp6dq0aZMyMjI0efLkSut/8cUX6tu3r/7yl79IkqKiojRs2DB99dVXNT6br68vf4EDAHADsuyG4rKyMmVnZys+Pv7fw/j4KD4+Xtu3b69ymz59+ig7O1s7duyQJH333XfavHmz7r77bo/MDAAAvJ9lZ25Onjwph8Oh8PDwCsvDw8O1f//+Krf5y1/+opMnT+p3v/udnE6nfvnlFz3yyCN6+umnL7uf0tJSlZaWup4XFRXVzAEAAACvdEO9Fdxut2vu3Ln629/+pt27d+utt97Spk2b9Nxzz112m5SUFAUFBbkekZGRHpwYAAB4mmVnbkJDQ+Xr66uCgoIKywsKCi77Lpzp06dr+PDhGjNmjCSpc+fOKikp0cMPP6ypU6fKx6dyq02ZMkXJycmu50VFRQQOAAAGs+zMjb+/v3r06KHMzEzXsvLycmVmZio2NrbKbc6ePVspYC7d9Hu5T8wNCAhQo0aNKjwAAIC5LH23VHJyskaOHKmePXuqV69eSktLU0lJievdUyNGjFCLFi2UkpIiSRo8eLBSU1PVrVs39e7dW4cOHdL06dM1ePBg3tkEAAAkWRw3Q4cO1Y8//qgZM2YoPz9fMTEx2rJli+sm47y8vApnaqZNmyabzaZp06bp+PHjatq0qQYPHqw5c+ZYdQgAAMDLWP4JxUlJSUpKSqryZ3a7vcLzOnXqaObMmZo5c6YHJgMAADeiG+rdUgAAAL+FuAEAAEYhbgAAgFEsv+emtsl7trPb99Fyxl637wMAAG/FmRsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARuFzbqAeE19z+z42NHT7
LgAAkMSZGwAAYBjiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGMXyuFm8eLGioqIUGBio3r17a8eOHVdc/+eff9a4ceMUERGhgIAA3XLLLdq8ebOHpgUAAN6ujpU7X7dunZKTk5Wenq7evXsrLS1NCQkJOnDggMLCwiqtX1ZWpjvvvFNhYWF644031KJFC33//fcKDg72/PAAAMArWRo3qampGjt2rBITEyVJ6enp2rRpkzIyMjR58uRK62dkZOjUqVP64osv5OfnJ0mKiory5MgAAMDLWXZZqqysTNnZ2YqPj//3MD4+io+P1/bt26vcZuPGjYqNjdW4ceMUHh6uW2+9VXPnzpXD4bjsfkpLS1VUVFThAQAAzGVZ3Jw8eVIOh0Ph4eEVloeHhys/P7/Kbb777ju98cYbcjgc2rx5s6ZPn64FCxZo9uzZl91PSkqKgoKCXI/IyMgaPQ4AAOBdLL+h+GqUl5crLCxMS5cuVY8ePTR06FBNnTpV6enpl91mypQpOn36tOtx7NgxD04MAAA8zbJ7bkJDQ+Xr66uCgoIKywsKCtSsWbMqt4mIiJCfn598fX1dyzp06KD8/HyVlZXJ39+/0jYBAQEKCAio2eEBAIDXsuzMjb+/v3r06KHMzEzXsvLycmVmZio2NrbKbfr27atDhw6pvLzctezgwYOKiIioMmwAAEDtY+llqeTkZC1btkyrVq3SN998o0cffVQlJSWud0+NGDFCU6ZMca3/6KOP6tSpUxo/frwOHjyoTZs2ae7cuRo3bpxVhwAAALyMpW8FHzp0qH788UfNmDFD+fn5iomJ0ZYtW1w3Gefl5cnH59/9FRkZqa1bt+qJJ55Qly5d1KJFC40fP16TJk2y6hAAAICXsTRuJCkpKUlJSUlV/sxut1daFhsbqy+//NLNUwEAgBvVDfVuKQAAgN9C3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADCK5V+c6U16THzN7fvY0NDtuwAAoFbjzA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKNcUNwMGDNDPP/9caXlRUZEGDBhwvTMBAABcs2uKG7vdrrKyskrLz58/r88+++y6hwIAALhWV/X1C//93//t+vO//vUv5efnu547HA5t2bJFLVq0qLnpAAAArtJVxU1MTIxsNptsNluVl5/q1q2rV155pcaGAwAAuFpXFTdHjhyR0+nUzTffrB07dqhp06aun/n7+yssLEy+vr41PiQAAEB1XVXctGrVSpJUXl7ulmEAAACuV7XjZuPGjRo4cKD8/Py0cePGK657zz33XPdgAAAA16LacTNkyBDl5+crLCxMQ4YMuex6NptNDoejJmYDAAC4atWOm19fiuKyFAAA8FZXdc/Nr2VmZiozM1OFhYUVYsdms2nFihU1MhwAAMDVuqa4mTVrlp599ln17NlTERERstlsNT0XAADANbmmuElPT9fKlSs1fPjwmp4HAADgulzT1y+UlZWpT58+NT0LAADAdbumuBkzZozWrFlT07MAAABct2pflkpOTnb9uby8XEuXLtVHH32kLl26yM/Pr8K6qampNTchAADA
Vah23OTk5FR4HhMTI0nat29fheXcXAwAAKxU7bj55JNP3DkHAABAjbime24AAAC8FXEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKF4RN4sXL1ZUVJQCAwPVu3dv7dixo1rbrV27VjabTUOGDHHvgAAA4IZhedysW7dOycnJmjlzpnbv3q2uXbsqISFBhYWFV9zu6NGjevLJJ9WvXz8PTQoAAG4ElsdNamqqxo4dq8TERHXs2FHp6emqV6+eMjIyLruNw+HQAw88oFmzZunmm2/24LQAAMDbWRo3ZWVlys7OVnx8vGuZj4+P4uPjtX379stu9+yzzyosLEyjR4/+zX2UlpaqqKiowgMAAJjL0rg5efKkHA6HwsPDKywPDw9Xfn5+ldts27ZNK1as0LJly6q1j5SUFAUFBbkekZGR1z03AADwXpZflroaZ86c0fDhw7Vs2TKFhoZWa5spU6bo9OnTrsexY8fcPCUAALBSHSt3HhoaKl9fXxUUFFRYXlBQoGbNmlVa//Dhwzp69KgGDx7sWlZeXi5JqlOnjg4cOKA2bdpU2CYgIEABAQFumB4AAHgjS8/c+Pv7q0ePHsrMzHQtKy8vV2ZmpmJjYyutHx0drb179yo3N9f1uOeee/T73/9eubm5XHICAADWnrmRpOTkZI0cOVI9e/ZUr169lJaWppKSEiUmJkqSRowYoRYtWiglJUWBgYG69dZbK2wfHBwsSZWWAwCA2snyuBk6dKh+/PFHzZgxQ/n5+YqJidGWLVtcNxnn5eXJx+eGujUIAABYyPK4kaSkpCQlJSVV+TO73X7FbVeuXFnzAwEAgBsWp0QAAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRiBsAAGAU4gYAABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGAAAYhbgBAABGIW4AAIBRvCJuFi9erKioKAUGBqp3797asWPHZdddtmyZ+vXrp8aNG6tx48aKj4+/4voAAKB2sTxu1q1bp+TkZM2cOVO7d+9W165dlZCQoMLCwirXt9vtGjZsmD755BNt375dkZGRuuuuu3T8+HEPTw4AALyR5XGTmpqqsWPHKjExUR07dlR6errq1aunjIyMKtdfvXq1HnvsMcXExCg6OlrLly9XeXm5MjMzPTw5AADwRpbGTVlZmbKzsxUfH+9a5uPjo/j4eG3fvr1av+Ps2bO6cOGCQkJC3DUmAAC4gdSxcucnT56Uw+FQeHh4heXh4eHav39/tX7HpEmT1Lx58wqB9GulpaUqLS11PS8qKrr2gQEAgNez/LLU9Zg3b57Wrl2rDRs2KDAwsMp1UlJSFBQU5HpERkZ6eEoAAOBJlsZNaGiofH19VVBQUGF5QUGBmjVrdsVtX3zxRc2bN08ffPCBunTpctn1pkyZotOnT7sex44dq5HZAQCAd7I0bvz9/dWjR48KNwNfujk4Njb2sts9//zzeu6557Rlyxb17NnzivsICAhQo0aN
KjwAAIC5LL3nRpKSk5M1cuRI9ezZU7169VJaWppKSkqUmJgoSRoxYoRatGihlJQUSdL8+fM1Y8YMrVmzRlFRUcrPz5ckNWjQQA0aNLDsOAAAgHewPG6GDh2qH3/8UTNmzFB+fr5iYmK0ZcsW103GeXl58vH59wmmJUuWqKysTH/6058q/J6ZM2fqmWee8eToAADAC1keN5KUlJSkpKSkKn9mt9srPD969Kj7BwIAADesG/rdUgAAAP8fcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjELcAAAAoxA3AADAKMQNAAAwCnEDAACMQtwAAACjEDcAAMAoxA0AADAKcQMAAIziFXGzePFiRUVFKTAwUL1799aOHTuuuP5//dd/KTo6WoGBgercubM2b97soUkBAIC3szxu1q1bp+TkZM2cOVO7d+9W165dlZCQoMLCwirX/+KLLzRs2DCNHj1aOTk5GjJkiIYMGaJ9+/Z5eHIAAOCNLI+b1NRUjR07VomJierYsaPS09NVr149ZWRkVLn+Sy+9pD/84Q+aOHGiOnTooOeee07du3fXokWLPDw5AADwRpbGTVlZmbKzsxUfH+9a5uPjo/j4eG3fvr3KbbZv315hfUlKSEi47PoAAKB2qWPlzk+ePCmHw6Hw8PAKy8PDw7V///4qt8nPz69y/fz8/CrXLy0tVWlpqev56dOnJUlFRUWV1nWUnruq+a/FGT+H2/dR1bFdCcftPhy3+3Dc1cNxuw/H7T5VHfelZU6n8ze3tzRuPCElJUWzZs2qtDwyMtKCaaRbPbGTlCBP7OWqcNxuxHF7DY7bjThur2H1cZ85c0ZBQVd+XSyNm9DQUPn6+qqgoKDC8oKCAjVr1qzKbZo1a3ZV60+ZMkXJycmu5+Xl5Tp16pSaNGkim812nUdwdYqKihQZGaljx46pUaNGHt23lThujrs24Lg57trAyuN2Op06c+aMmjdv/pvrWho3/v7+6tGjhzIzMzVkyBBJF+MjMzNTSUlJVW4TGxurzMxMTZgwwbXsww8/VGxsbJXrBwQEKCAgoMKy4ODgmhj/mjVq1KhW/Y/hEo67duG4axeOu3ax6rh/64zNJZZflkpOTtbIkSPVs2dP9erVS2lpaSopKVFiYqIkacSIEWrRooVSUlIkSePHj1f//v21YMECDRo0SGvXrtWuXbu0dOlSKw8DAAB4CcvjZujQofrxxx81Y8YM5efnKyYmRlu2bHHdNJyXlycfn3+/qatPnz5as2aNpk2bpqefflrt2rXT22+/rVtv9chVQAAA4OUsjxtJSkpKuuxlKLvdXmnZfffdp/vuu8/NU9W8gIAAzZw5s9JlMtNx3Bx3bcBxc9y1wY1y3DZndd5TBQAAcIOw/BOKAQAAahJxAwAAjELcAAAAoxA3HvDpp59q8ODBat68uWw2m95++22rR/KIlJQU3XbbbWrYsKHCwsI0ZMgQHThwwOqx3G7JkiXq0qWL63MgYmNj9f7771s9lsfNmzdPNputwmdSmeiZZ56RzWar8IiOjrZ6LI84fvy4HnzwQTVp0kR169ZV586dtWvXLqvHcquoqKhK/71tNpvGjRtn9Whu5XA4NH36dLVu3Vp169ZVmzZt9Nxzz1Xr
qxCs4BXvljJdSUmJunbtqlGjRunee++1ehyPycrK0rhx43Tbbbfpl19+0dNPP6277rpL//rXv1S/fn2rx3Obm266SfPmzVO7du3kdDq1atUq/fGPf1ROTo46depk9XgesXPnTv39739Xly5drB7FIzp16qSPPvrI9bxOHfP/r/Wnn35S37599fvf/17vv/++mjZtqm+//VaNGze2ejS32rlzpxyOf3+30r59+3TnnXfekO/gvRrz58/XkiVLtGrVKnXq1Em7du1SYmKigoKC9Ne//tXq8Sox/3+BXmDgwIEaOHCg1WN43JYtWyo8X7lypcLCwpSdna3bb7/doqncb/DgwRWez5kzR0uWLNGXX35ZK+KmuLhYDzzwgJYtW6bZs2dbPY5H1KlT57JfAWOq+fPnKzIyUq+++qprWevWrS2cyDOaNm1a4fm8efPUpk0b9e/f36KJPOOLL77QH//4Rw0aNEjSxTNY//znP7Vjxw6LJ6sal6XgMZe+kT0kJMTiSTzH4XBo7dq1KikpuexXhJhm3LhxGjRokOLj460exWO+/fZbNW/eXDfffLMeeOAB5eXlWT2S223cuFE9e/bUfffdp7CwMHXr1k3Lli2zeiyPKisr0z/+8Q+NGjXK499V6Gl9+vRRZmamDh48KEnas2ePtm3b5rX/cOfMDTyivLxcEyZMUN++fWvFp0nv3btXsbGxOn/+vBo0aKANGzaoY8eOVo/ldmvXrtXu3bu1c+dOq0fxmN69e2vlypVq3769Tpw4oVmzZqlfv37at2+fGjZsaPV4bvPdd99pyZIlSk5O1tNPP62dO3fqr3/9q/z9/TVy5Eirx/OIt99+Wz///LMeeughq0dxu8mTJ6uoqEjR0dHy9fWVw+HQnDlz9MADD1g9WpWIG3jEuHHjtG/fPm3bts3qUTyiffv2ys3N1enTp/XGG29o5MiRysrKMjpwjh07pvHjx+vDDz9UYGCg1eN4zK//5dqlSxf17t1brVq10vr16zV69GgLJ3Ov8vJy9ezZU3PnzpUkdevWTfv27VN6enqtiZsVK1Zo4MCB1fqW6hvd+vXrtXr1aq1Zs0adOnVSbm6uJkyYoObNm3vlf2/iBm6XlJSk9957T59++qluuukmq8fxCH9/f7Vt21aS1KNHD+3cuVMvvfSS/v73v1s8mftkZ2ersLBQ3bt3dy1zOBz69NNPtWjRIpWWlsrX19fCCT0jODhYt9xyiw4dOmT1KG4VERFRKdY7dOigN99806KJPOv777/XRx99pLfeesvqUTxi4sSJmjx5sv785z9Lkjp37qzvv/9eKSkpxA1qF6fTqccff1wbNmyQ3W6vFTcbXk55eblKS0utHsOt7rjjDu3du7fCssTEREVHR2vSpEm1ImykizdUHz58WMOHD7d6FLfq27dvpY92OHjwoFq1amXRRJ716quvKiwszHWDrenOnj1b4UusJcnX11fl5eUWTXRlxI0HFBcXV/hX3JEjR5Sbm6uQkBC1bNnSwsnca9y4cVqzZo3eeecdNWzYUPn5+ZKkoKAg1a1b1+Lp3GfKlCkaOHCgWrZsqTNnzmjNmjWy2+3aunWr1aO5VcOGDSvdT1W/fn01adLE6PusnnzySQ0ePFitWrXSDz/8oJkzZ8rX11fDhg2zejS3euKJJ9SnTx/NnTtX999/v3bs2KGlS5dq6dKlVo/mduXl5Xr11Vc1cuTIWvG2f+niu0DnzJmjli1bqlOnTsrJyVFqaqpGjRpl9WhVc8LtPvnkE6ekSo+RI0daPZpbVXXMkpyvvvqq1aO51ahRo5ytWrVy+vv7O5s2beq84447nB988IHVY1mif//+zvHjx1s9hlsNHTrUGRER4fT393e2aNHCOXToUOehQ4esHssj3n33Xeett97qDAgIcEZHRzuXLl1q9UgesXXrVqck54EDB6wexWOKioqc48ePd7Zs2dIZGBjovPnmm51Tp051lpaWWj1alfhWcAAAYBQ+5wYA
ABiFuAEAAEYhbgAAgFGIGwAAYBTiBgAAGIW4AQAARiFuAACAUYgbAABgFOIGQK0WFRWltLQ0q8cAUIOIGwBeIy4uThMmTLB6DAA3OOIGAAAYhbgB4BUeeughZWVl6aWXXpLNZpPNZtPRo0eVlZWlXr16KSAgQBEREZo8ebJ++eUX13ZxcXFKSkpSUlKSgoKCFBoaqunTp+tavzZv+fLlCg4OVmZmZk0dGgAPI24AeIWXXnpJsbGxGjt2rE6cOKETJ07Iz89Pd999t2677Tbt2bNHS5Ys0YoVKzR79uwK265atUp16tTRjh079NJLLyk1NVXLly+/6hmef/55TZ48WR988IHuuOOOmjo0AB5Wx+oBAECSgoKC5O/vr3r16qlZs2aSpKlTpyoyMlKLFi2SzWZTdHS0fvjhB02aNEkzZsyQj8/Ff59FRkZq4cKFstlsat++vfbu3auFCxdq7Nix1d7/pEmT9PrrrysrK0udOnVyyzEC8AzO3ADwWt98841iY2Nls9lcy/r27avi4mL9z//8j2vZf/zHf1RYJzY2Vt9++60cDke19rNgwQItW7ZM27ZtI2wAAxA3AGq9fv36yeFwaP369VaPAqAGEDcAvIa/v3+Fsy0dOnTQ9u3bK9wc/Pnnn6thw4a66aabXMu++uqrCr/nyy+/VLt27eTr61ut/fbq1Uvvv/++5s6dqxdffPE6jwKA1YgbAF4jKipKX331lY4ePaqTJ0/qscce07Fjx/T4449r//79eueddzRz5kwlJye77reRpLy8PCUnJ+vAgQP65z//qVdeeUXjx4+/qn336dNHmzdv1qxZs/hQP+AGxw3FALzGk08+qZEjR6pjx446d+6cjhw5os2bN2vixInq2rWrQkJCNHr0aE2bNq3CdiNGjNC5c+fUq1cv+fr6avz48Xr44Yevev+/+93vtGnTJt19993y9fXV448/XlOHBsCDbM5r/TAIAPACcXFxiomJ4WwLABcuSwEAAKMQNwCM9dlnn6lBgwaXfQAwE5elABjr3LlzOn78+GV/3rZtWw9OA8BTiBsAAGAULksBAACjEDcAAMAoxA0AADAKcQMAAIxC3AAAAKMQNwAAwCjEDQAAMApxAwAAjPJ/Oy9zODG79JEAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "# Bars per top_k are split by the `reranker` label ('w/o' vs 'w/') to compare the two setups\n",
    "sns.barplot(x='top_k', y='hit', hue='reranker', data=hit_stat_df, errorbar=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7925564a-7d30-4914-baaf-4a00abb7686d",
   "metadata": {},
   "source": [
    "# 生成答案"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "badab849-5b3d-4ed7-8ecf-ea77fbc0b8f9",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:21:51.452814Z",
     "iopub.status.busy": "2024-09-02T07:21:51.452643Z",
     "iopub.status.idle": "2024-09-02T07:21:51.462265Z",
     "shell.execute_reply": "2024-09-02T07:21:51.461878Z",
     "shell.execute_reply.started": "2024-09-02T07:21:51.452801Z"
    }
   },
   "outputs": [],
   "source": [
    "from langchain.llms import Ollama\n",
    "\n",
    "# Generator model used by rag(); requests go to the Ollama server at base_url\n",
    "llm = Ollama(\n",
    "    model='qwen2:7b-instruct',\n",
    "    base_url='http://localhost:11434',\n",
    "    temperature=0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "50404beb-3be0-4aaa-b124-8c7a52b84531",
   "metadata": {
    "editable": true,
    "execution": {
     "iopub.execute_input": "2024-09-02T07:21:51.462875Z",
     "iopub.status.busy": "2024-09-02T07:21:51.462733Z",
     "iopub.status.idle": "2024-09-02T07:21:51.473017Z",
     "shell.execute_reply": "2024-09-02T07:21:51.472562Z",
     "shell.execute_reply.started": "2024-09-02T07:21:51.462863Z"
    },
    "slideshow": {
     "slide_type": ""
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def rag(retriever, query, n_chunks=4, retirever_multiplier=3):\n",
    "    \"\"\"\n",
    "    Answer a question by retrieving candidates, reranking them and prompting the LLM.\n",
    "\n",
    "    Uses the notebook-level `reranker` and `llm` objects.\n",
    "\n",
    "    :param retriever: object exposing get_relevant_documents(query)\n",
    "    :param query: the question to answer\n",
    "    :param n_chunks: number of reranked chunks placed into the prompt\n",
    "    :param retirever_multiplier: at most n_chunks * retirever_multiplier retrieved candidates are reranked\n",
    "    :return: tuple (generated answer, chunks used as context)\n",
    "    \"\"\"\n",
    "    prompt_tmpl = \"\"\"\n",
    "你是一个金融分析师，擅长根据所获取的信息片段，对问题进行分析和推理。\n",
    "你的任务是根据所获取的信息片段（<<<<context>>><<<</context>>>之间的内容）回答问题。\n",
    "回答保持简洁，不必重复问题，不要添加描述性解释和与答案无关的任何内容。\n",
    "已知信息：\n",
    "<<<<context>>>\n",
    "{{knowledge}}\n",
    "<<<</context>>>\n",
    "\n",
    "问题：{{query}}\n",
    "请回答：\n",
    "\"\"\".strip()\n",
    "\n",
    "    chunks = retriever.get_relevant_documents(query)[:n_chunks * retirever_multiplier]\n",
    "    chunks = rerank(reranker, query, chunks, top_k=n_chunks)\n",
    "    knowledge = '\\n\\n'.join(doc.page_content for doc in chunks)\n",
    "    prompt = prompt_tmpl.replace('{{knowledge}}', knowledge).replace('{{query}}', query)\n",
    "\n",
    "    # Calling the LLM object directly, llm(prompt), is deprecated in langchain-core; invoke replaces it\n",
    "    return llm.invoke(prompt), chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "b727d127-5383-4089-b763-28e2c23aec3a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:21:51.473619Z",
     "iopub.status.busy": "2024-09-02T07:21:51.473498Z",
     "iopub.status.idle": "2024-09-02T07:21:55.082725Z",
     "shell.execute_reply": "2024-09-02T07:21:55.082231Z",
     "shell.execute_reply.started": "2024-09-02T07:21:51.473607Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/lib/python3.10/site-packages/langchain_core/_api/deprecation.py:139: LangChainDeprecationWarning: The method `BaseLLM.__call__` was deprecated in langchain-core 0.1.7 and will be removed in 0.3.0. Use invoke instead.\n",
      "  warn_deprecated(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023年10月美国ISM制造业PMI指数较上个月大幅下降了2.3个百分点。\n"
     ]
    }
   ],
   "source": [
    "n_chunks = 3\n",
    "\n",
    "get_retriever_fn = build_get_ensemble_retriver_fn(weights=[0.5, 0.5])\n",
    "# rag() reranks up to n_chunks * retirever_multiplier candidates, so build the retriever with that\n",
    "# size (as the reranker hit-rate cell does) instead of n_chunks; retirever_multiplier is defined there.\n",
    "retriever = get_retriever_fn(n_chunks * retirever_multiplier)\n",
    "# Pass n_chunks explicitly; otherwise rag() falls back to its default of 4\n",
    "print(rag(retriever, '2023年10月美国ISM制造业PMI指数较上月有何变化？', n_chunks=n_chunks, retirever_multiplier=retirever_multiplier)[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95e5a804-2dc6-411c-ba71-6ccf765b2b73",
   "metadata": {},
   "source": [
    "## 预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "166392d8-f801-4372-b8ad-3e79aef0b350",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:21:55.083363Z",
     "iopub.status.busy": "2024-09-02T07:21:55.083230Z",
     "iopub.status.idle": "2024-09-02T07:21:55.088001Z",
     "shell.execute_reply": "2024-09-02T07:21:55.087668Z",
     "shell.execute_reply.started": "2024-09-02T07:21:55.083351Z"
    }
   },
   "outputs": [],
   "source": [
    "prediction_df = qa_df[qa_df['dataset'] == 'test'][['uuid', 'question', 'qa_type', 'answer']].rename(columns={'answer': 'ref_answer'})\n",
    "\n",
    "def predict(prediction_df, retriever, n_chunks):\n",
    "    \"\"\"\n",
    "    Generate an answer for every question in prediction_df.\n",
    "\n",
    "    :param prediction_df: DataFrame with a `question` column\n",
    "    :param retriever: retriever handed to rag()\n",
    "    :param n_chunks: number of context chunks used per question\n",
    "    :return: copy of prediction_df with `gen_answer` and `chunks` columns added\n",
    "    \"\"\"\n",
    "    prediction_df = prediction_df.copy()\n",
    "    gen_answers = []\n",
    "    chunk_lists = []\n",
    "\n",
    "    # Collect results in row order rather than keyed by question text, so duplicate\n",
    "    # questions cannot overwrite each other\n",
    "    for question in tqdm(prediction_df['question'], total=len(prediction_df)):\n",
    "        answer, chunks = rag(retriever, question, n_chunks=n_chunks)\n",
    "        assert len(chunks) <= n_chunks\n",
    "        gen_answers.append(answer)\n",
    "        chunk_lists.append(chunks)\n",
    "\n",
    "    prediction_df['gen_answer'] = gen_answers\n",
    "    prediction_df['chunks'] = chunk_lists\n",
    "\n",
    "    return prediction_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "ca46d5f1-e698-457d-abb6-92d83cd59c66",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:21:55.088558Z",
     "iopub.status.busy": "2024-09-02T07:21:55.088428Z",
     "iopub.status.idle": "2024-09-02T07:26:02.026910Z",
     "shell.execute_reply": "2024-09-02T07:26:02.026522Z",
     "shell.execute_reply.started": "2024-09-02T07:21:55.088546Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a1addb53b4c547c9840eab6167e8edb4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Generate answers for the test questions with the retriever and n_chunks configured above\n",
    "pred_df = predict(prediction_df, retriever, n_chunks=n_chunks)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f45bd41-43a4-4d11-bf1f-b16d9693bfa2",
   "metadata": {},
   "source": [
    "# 评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "217568fe-c0e4-49eb-9a7c-9fdfbc033d8a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:26:02.027577Z",
     "iopub.status.busy": "2024-09-02T07:26:02.027451Z",
     "iopub.status.idle": "2024-09-02T07:26:02.204368Z",
     "shell.execute_reply": "2024-09-02T07:26:02.203904Z",
     "shell.execute_reply.started": "2024-09-02T07:26:02.027565Z"
    }
   },
   "outputs": [],
   "source": [
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "# Judge model used by evaluate(); API key and base URL are read from the environment\n",
    "judge_llm = ChatOpenAI(\n",
    "    api_key=os.environ['LLM_API_KEY'],\n",
    "    base_url=os.environ['LLM_BASE_URL'],\n",
    "    model_name='qwen2-72b-instruct',\n",
    "    temperature=0\n",
    ")\n",
    "\n",
    "import time\n",
    "\n",
    "def evaluate(prediction_df):\n",
    "    \"\"\"\n",
    "    Grade generated answers with the judge LLM.\n",
    "\n",
    "    :param prediction_df: predictions containing the columns question, ref_answer and gen_answer\n",
    "    :return: list with the judge model's raw reply for each row, in row order\n",
    "    \"\"\"\n",
    "    prompt_tmpl = \"\"\"\n",
    "你是一个经济学博士，现在我有一系列问题，有一个助手已经对这些问题进行了回答，你需要参照参考答案，评价这个助手的回答是否正确，仅回复“是”或“否”即可，不要带其他描述性内容或无关信息。\n",
    "问题：\n",
    "<question>\n",
    "{{question}}\n",
    "</question>\n",
    "\n",
    "参考答案：\n",
    "<ref_answer>\n",
    "{{ref_answer}}\n",
    "</ref_answer>\n",
    "\n",
    "助手回答：\n",
    "<gen_answer>\n",
    "{{gen_answer}}\n",
    "</gen_answer>\n",
    "请评价：\n",
    "    \"\"\"\n",
    "    results = []\n",
    "\n",
    "    for _, row in tqdm(prediction_df.iterrows(), total=len(prediction_df)):\n",
    "        # Coerce every field to str: str.replace raises TypeError on non-string values such as NaN\n",
    "        prompt = (\n",
    "            prompt_tmpl\n",
    "            .replace('{{question}}', str(row['question']))\n",
    "            .replace('{{ref_answer}}', str(row['ref_answer']))\n",
    "            .replace('{{gen_answer}}', str(row['gen_answer']))\n",
    "            .strip()\n",
    "        )\n",
    "        results.append(judge_llm.invoke(prompt).content)\n",
    "\n",
    "        # Pause between judge requests (presumably to stay under the API rate limit)\n",
    "        time.sleep(1)\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "f3dbca91-125c-450e-87ea-eedde90fe994",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:26:02.204970Z",
     "iopub.status.busy": "2024-09-02T07:26:02.204837Z",
     "iopub.status.idle": "2024-09-02T07:28:39.307055Z",
     "shell.execute_reply": "2024-09-02T07:28:39.304699Z",
     "shell.execute_reply.started": "2024-09-02T07:26:02.204958Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "65346a24910d4534ad2062a8473226a5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Store the judge's raw reply for every prediction\n",
    "pred_df['raw_score'] = evaluate(pred_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "a87e45ef-cd70-4731-8145-9a75c677e08a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:28:39.310197Z",
     "iopub.status.busy": "2024-09-02T07:28:39.309486Z",
     "iopub.status.idle": "2024-09-02T07:28:39.316145Z",
     "shell.execute_reply": "2024-09-02T07:28:39.315821Z",
     "shell.execute_reply.started": "2024-09-02T07:28:39.310130Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['是', '否',\n",
       "       '否\\n\\n（注：虽然助手的回答提供了更多的背景信息和解释，但是参考答案只提到了三个国家的名字，而没有提及关于能源进口大国或者俄乌冲突的背景。所以，从字面上看，助手的回答与参考答案并不完全一致。）'],\n",
       "      dtype=object)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_df['raw_score'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "cacde907-9350-488c-b63d-4d3a9c92c40c",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:28:39.316729Z",
     "iopub.status.busy": "2024-09-02T07:28:39.316600Z",
     "iopub.status.idle": "2024-09-02T07:28:39.329889Z",
     "shell.execute_reply": "2024-09-02T07:28:39.329357Z",
     "shell.execute_reply.started": "2024-09-02T07:28:39.316716Z"
    }
   },
   "outputs": [],
   "source": [
    "pred_df['score'] = (pred_df['raw_score'] == '是').astype(int)\n",
    "# pred_df.loc[:, 'score'] = pred_df['raw_score'].replace({'是': 1, '否': 0})\n",
    "# _ = pred_df.pop('raw_score')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "79325429-9cf1-4e2c-95ac-cb0c1a3b6156",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-09-02T07:28:39.330628Z",
     "iopub.status.busy": "2024-09-02T07:28:39.330397Z",
     "iopub.status.idle": "2024-09-02T07:28:39.336005Z",
     "shell.execute_reply": "2024-09-02T07:28:39.335633Z",
     "shell.execute_reply.started": "2024-09-02T07:28:39.330613Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.81"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_df['score'].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cd9c8e9-a0f4-47cf-b31d-c92462f9518d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
